import copy
import hashlib
from collections import deque
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import scipy

from ray.rllib.core import DEFAULT_AGENT_ID, DEFAULT_MODULE_ID
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.env.utils.infinite_lookback_buffer import InfiniteLookbackBuffer
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import (
    OverrideToImplementCustomLogic_CallToSuperRecommended,
    override,
)
from ray.rllib.utils.metrics import (
    ACTUAL_N_STEP,
    AGENT_ACTUAL_N_STEP,
    AGENT_STEP_UTILIZATION,
    ENV_STEP_UTILIZATION,
    MODULE_ACTUAL_N_STEP,
    MODULE_STEP_UTILIZATION,
    NUM_AGENT_EPISODES_ADDED,
    NUM_AGENT_EPISODES_ADDED_LIFETIME,
    NUM_AGENT_EPISODES_EVICTED,
    NUM_AGENT_EPISODES_EVICTED_LIFETIME,
    NUM_AGENT_EPISODES_PER_SAMPLE,
    NUM_AGENT_EPISODES_STORED,
    NUM_AGENT_RESAMPLES,
    NUM_AGENT_STEPS_ADDED,
    NUM_AGENT_STEPS_ADDED_LIFETIME,
    NUM_AGENT_STEPS_EVICTED,
    NUM_AGENT_STEPS_EVICTED_LIFETIME,
    NUM_AGENT_STEPS_PER_SAMPLE,
    NUM_AGENT_STEPS_PER_SAMPLE_LIFETIME,
    NUM_AGENT_STEPS_SAMPLED,
    NUM_AGENT_STEPS_SAMPLED_LIFETIME,
    NUM_AGENT_STEPS_STORED,
    NUM_ENV_STEPS_ADDED,
    NUM_ENV_STEPS_ADDED_LIFETIME,
    NUM_ENV_STEPS_EVICTED,
    NUM_ENV_STEPS_EVICTED_LIFETIME,
    NUM_ENV_STEPS_PER_SAMPLE,
    NUM_ENV_STEPS_PER_SAMPLE_LIFETIME,
    NUM_ENV_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED_LIFETIME,
    NUM_ENV_STEPS_STORED,
    NUM_EPISODES_ADDED,
    NUM_EPISODES_ADDED_LIFETIME,
    NUM_EPISODES_EVICTED,
    NUM_EPISODES_EVICTED_LIFETIME,
    NUM_EPISODES_PER_SAMPLE,
    NUM_EPISODES_STORED,
    NUM_MODULE_EPISODES_ADDED,
    NUM_MODULE_EPISODES_ADDED_LIFETIME,
    NUM_MODULE_EPISODES_EVICTED,
    NUM_MODULE_EPISODES_EVICTED_LIFETIME,
    NUM_MODULE_EPISODES_PER_SAMPLE,
    NUM_MODULE_EPISODES_STORED,
    NUM_MODULE_RESAMPLES,
    NUM_MODULE_STEPS_ADDED,
    NUM_MODULE_STEPS_ADDED_LIFETIME,
    NUM_MODULE_STEPS_EVICTED,
    NUM_MODULE_STEPS_EVICTED_LIFETIME,
    NUM_MODULE_STEPS_PER_SAMPLE,
    NUM_MODULE_STEPS_PER_SAMPLE_LIFETIME,
    NUM_MODULE_STEPS_SAMPLED,
    NUM_MODULE_STEPS_SAMPLED_LIFETIME,
    NUM_RESAMPLES,
)
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
from ray.rllib.utils.replay_buffers.base import ReplayBufferInterface
from ray.rllib.utils.typing import AgentID, ModuleID, ResultDict, SampleBatchType


class EpisodeReplayBuffer(ReplayBufferInterface):
    """Buffer that stores (completed or truncated) episodes by their ID.

    Each "row" (a slot in a deque) in the buffer is occupied by one episode. If an
    incomplete episode is added to the buffer and then another chunk of that episode is
    added at a later time, the buffer will automatically concatenate the new fragment to
    the original episode. This way, episodes can be completed via subsequent `add`
    calls.

    Sampling returns batches of size B (number of "rows"), where each row is a
    trajectory of length T. Each trajectory contains consecutive timesteps from an
    episode, but might not start at the beginning of that episode. Should an episode end
    within such a trajectory, a random next episode (starting from its t0) will be
    concatenated to that "row". Example: `sample(B=2, T=4)` ->

       0 .. 1 .. 2 .. 3  <- T-axis
    0 e5   e6   e7   e8
    1 f2   f3   h0   h2
    ^ B-axis

    .. where e, f, and h are different (randomly picked) episodes, the 0-index (e.g. h0)
    indicates the start of an episode, and `f3` is an episode end (gym environment
    returned terminated=True or truncated=True).

    0-indexed returned timesteps contain the reset observation, a dummy 0.0 reward, as
    well as the first action taken in the episode (action picked after observing
    obs(0)).
    The last index in an episode (e.g. f3 in the example above) contains the final
    observation of the episode, the final reward received, a dummy action
    (repeat the previous action), as well as either terminated=True or truncated=True.
    """

    __slots__ = (
        "capacity",
        "batch_size_B",
        "batch_length_T",
        "episodes",
        "episode_id_to_index",
        "num_episodes_evicted",
        "_indices",
        "_num_timesteps",
        "_num_timesteps_added",
        "sampled_timesteps",
        "rng",
    )

    def __init__(
        self,
        capacity: int = 10000,
        *,
        batch_size_B: int = 16,
        batch_length_T: int = 64,
        metrics_num_episodes_for_smoothing: int = 100,
        **kwargs,
    ):
        """Initializes an EpisodeReplayBuffer instance.

        Args:
            capacity: The total number of timesteps to be storable in this buffer.
                Will start ejecting old episodes once this limit is reached.
            batch_size_B: The number of rows in a SampleBatch returned from `sample()`.
            batch_length_T: The length of each row in a SampleBatch returned from
                `sample()`.
        """
        self.capacity = capacity
        self.batch_size_B = batch_size_B
        self.batch_length_T = batch_length_T

        # The actual episode buffer. We are using a deque here for faster insertion
        # (left side) and eviction (right side) of data.
        self.episodes = deque()
        # Maps (unique) episode IDs to the index under which to find this episode
        # within our `self.episodes` deque.
        # Note that even after eviction started, the indices in here will NOT be
        # changed. We will therefore need to offset all indices in
        # `self.episode_id_to_index` by the number of episodes that have already been
        # evicted (self._num_episodes_evicted) in order to get the actual index to use
        # on `self.episodes`.
        self.episode_id_to_index = {}
        # The number of episodes that have already been evicted from the buffer
        # due to reaching capacity.
        self._num_episodes_evicted = 0

        # List storing all index tuples: (eps_idx, ts_in_eps_idx), where ...
        # `eps_idx - self._num_episodes_evicted' is the index into self.episodes.
        # `ts_in_eps_idx` is the timestep index within that episode
        #  (0 = 1st timestep, etc..).
        # We sample uniformly from the set of these indices in a `sample()`
        # call.
        self._indices = []

        # The size of the buffer in timesteps.
        self._num_timesteps = 0
        # The number of timesteps added thus far.
        self._num_timesteps_added = 0

        # How many timesteps have been sampled from the buffer in total?
        self.sampled_timesteps = 0

        self.rng = np.random.default_rng(seed=None)

        # Initialize the metrics.
        self.metrics = MetricsLogger()
        self._metrics_num_episodes_for_smoothing = metrics_num_episodes_for_smoothing

        # Initialize the metrics.
        self.metrics = MetricsLogger()
        self._metrics_num_episodes_for_smoothing = metrics_num_episodes_for_smoothing

    @override(ReplayBufferInterface)
    def __len__(self) -> int:
        return self.get_num_timesteps()

    @override(ReplayBufferInterface)
    def add(self, episodes: Union[List["SingleAgentEpisode"], "SingleAgentEpisode"]):
        """Converts incoming SampleBatch into a number of SingleAgentEpisode objects.

        Then adds these episodes to the internal deque.
        """
        episodes = force_list(episodes)

        # Set up some counters for metrics.
        num_env_steps_added = 0
        agent_to_num_steps_added = {DEFAULT_AGENT_ID: 0}
        module_to_num_steps_added = {DEFAULT_MODULE_ID: 0}
        num_episodes_added = 0
        agent_to_num_episodes_added = {DEFAULT_AGENT_ID: 0}
        module_to_num_episodes_added = {DEFAULT_MODULE_ID: 0}
        num_episodes_evicted = 0
        agent_to_num_episodes_evicted = {DEFAULT_AGENT_ID: 0}
        module_to_num_episodes_evicted = {DEFAULT_MODULE_ID: 0}
        num_env_steps_evicted = 0
        agent_to_num_steps_evicted = {DEFAULT_AGENT_ID: 0}
        module_to_num_steps_evicted = {DEFAULT_MODULE_ID: 0}

        for eps in episodes:
            # Make sure we don't change what's coming in from the user.
            # TODO (sven): It'd probably be better to make sure in the EnvRunner to not
            #  hold on to episodes (for metrics purposes only) that we are returning
            #  back to the user from `EnvRunner.sample()`. Then we wouldn't have to
            #  do any copying. Instead, either compile the metrics right away on the
            #  EnvRunner OR compile metrics entirely on the Algorithm side (this is
            #  actually preferred).
            eps = copy.deepcopy(eps)

            eps_len = len(eps)
            # TODO (simon): Check, if we can deprecate these two
            # variables and instead peek into the metrics.
            self._num_timesteps += eps_len
            self._num_timesteps_added += eps_len
            num_env_steps_added += eps_len
            agent_to_num_steps_added[DEFAULT_AGENT_ID] += eps_len
            module_to_num_steps_added[DEFAULT_MODULE_ID] += eps_len

            # Ongoing episode, concat to existing record.
            if eps.id_ in self.episode_id_to_index:
                eps_idx = self.episode_id_to_index[eps.id_]
                existing_eps = self.episodes[eps_idx - self._num_episodes_evicted]
                old_len = len(existing_eps)
                self._indices.extend([(eps_idx, old_len + i) for i in range(len(eps))])
                existing_eps.concat_episode(eps)
            # New episode. Add to end of our episodes deque.
            else:
                num_episodes_added += 1
                agent_to_num_episodes_added[DEFAULT_AGENT_ID] += 1
                module_to_num_episodes_added[DEFAULT_MODULE_ID] += 1
                self.episodes.append(eps)
                eps_idx = len(self.episodes) - 1 + self._num_episodes_evicted
                self.episode_id_to_index[eps.id_] = eps_idx
                self._indices.extend([(eps_idx, i) for i in range(len(eps))])

            # Eject old records from front of deque (only if we have more than 1 episode
            # in the buffer).
            while self._num_timesteps > self.capacity and self.get_num_episodes() > 1:
                # Eject oldest episode.
                evicted_eps = self.episodes.popleft()
                evicted_eps_len = len(evicted_eps)
                num_episodes_evicted += 1
                num_env_steps_evicted += evicted_eps_len
                agent_to_num_episodes_evicted[DEFAULT_AGENT_ID] += 1
                module_to_num_episodes_evicted[DEFAULT_MODULE_ID] += 1
                agent_to_num_steps_evicted[
                    DEFAULT_AGENT_ID
                ] += evicted_eps.agent_steps()
                module_to_num_steps_evicted[
                    DEFAULT_MODULE_ID
                ] += evicted_eps.agent_steps()
                # Correct our size.
                self._num_timesteps -= evicted_eps_len

                # Erase episode from all our indices:
                # 1) Main episode index.
                evicted_idx = self.episode_id_to_index[evicted_eps.id_]
                del self.episode_id_to_index[evicted_eps.id_]
                # 2) All timestep indices that this episode owned.
                new_indices = []  # New indices that will replace self._indices.
                idx_cursor = 0
                # Loop through all (eps_idx, ts_in_eps_idx)-tuples.
                for i, idx_tuple in enumerate(self._indices):
                    # This tuple is part of the evicted episode -> Add everything
                    # up until here to `new_indices` (excluding this very index, b/c
                    # it's already part of the evicted episode).
                    if idx_cursor is not None and idx_tuple[0] == evicted_idx:
                        new_indices.extend(self._indices[idx_cursor:i])
                        # Set to None to indicate we are in the eviction zone.
                        idx_cursor = None
                    # We are/have been in the eviction zone (i pointing/pointed to the
                    # evicted episode) ..
                    elif idx_cursor is None:
                        # ... but are now not anymore (i is now an index into a
                        # non-evicted episode) -> Set cursor to valid int again.
                        if idx_tuple[0] != evicted_idx:
                            idx_cursor = i
                            # But early-out if evicted episode was only 1 single
                            # timestep long.
                            if evicted_eps_len == 1:
                                break
                        # Early-out: We reached the end of the to-be-evicted episode.
                        # We can stop searching further here (all following tuples
                        # will NOT be in the evicted episode).
                        elif idx_tuple[1] == evicted_eps_len - 1:
                            assert self._indices[i + 1][0] != idx_tuple[0]
                            idx_cursor = i + 1
                            break

                # Jump over (splice-out) the evicted episode if we are still in the
                # eviction zone.
                if idx_cursor is not None:
                    new_indices.extend(self._indices[idx_cursor:])

                # Reset our `self._indices` to the newly compiled list.
                self._indices = new_indices

                # Increase episode evicted counter.
                self._num_episodes_evicted += 1

        # Update the metrics.
        self._update_add_metrics(
            num_episodes_added=num_episodes_added,
            num_env_steps_added=num_env_steps_added,
            num_episodes_evicted=num_episodes_evicted,
            num_env_steps_evicted=num_env_steps_evicted,
            agent_to_num_episodes_added=agent_to_num_episodes_added,
            agent_to_num_steps_added=agent_to_num_steps_added,
            agent_to_num_episodes_evicted=agent_to_num_episodes_evicted,
            agent_to_num_steps_evicted=agent_to_num_steps_evicted,
            module_to_num_episodes_added=module_to_num_steps_added,
            module_to_num_steps_added=module_to_num_episodes_added,
            module_to_num_episodes_evicted=module_to_num_episodes_evicted,
            module_to_num_steps_evicted=module_to_num_steps_evicted,
        )

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def _update_add_metrics(
        self,
        *,
        num_episodes_added: int,
        num_env_steps_added: int,
        num_episodes_evicted: int,
        num_env_steps_evicted: int,
        agent_to_num_episodes_added: Dict[AgentID, int],
        agent_to_num_steps_added: Dict[AgentID, int],
        agent_to_num_episodes_evicted: Dict[AgentID, int],
        agent_to_num_steps_evicted: Dict[AgentID, int],
        module_to_num_steps_added: Dict[ModuleID, int],
        module_to_num_episodes_added: Dict[ModuleID, int],
        module_to_num_episodes_evicted: Dict[ModuleID, int],
        module_to_num_steps_evicted: Dict[ModuleID, int],
        **kwargs,
    ) -> None:
        """Updates the replay buffer's adding metrics.

        Args:
            num_episodes_added: The total number of episodes added to the
                buffer in the `EpisodeReplayBuffer.add` call.
            num_timesteps_added: The total number of environment steps added to the
                buffer in the `EpisodeReplayBuffer.add` call.
            num_episodes_evicted: The total number of environment steps evicted from
                the buffer in the `EpisodeReplayBuffer.add` call. Note, this
                does not include the number of episodes evicted before ever
                added to the buffer (i.e. can happen in case a lot of episodes
                were added and the buffer's capacity is not large enough).
            num_env_steps_evicted: he total number of environment steps evicted from
                the buffer in the `EpisodeReplayBuffer.add` call. Note, this
                does not include the number of steps evicted before ever
                added to the buffer (i.e. can happen in case a lot of episodes
                were added and the buffer's capacity is not large enough).
            agent_to_num_episodes_added: A dictionary with the number of episodes per
                agent ID added to the buffer during the `EpisodeReplayBuffer.add` call.
            agent_to_num_steps_added: A dictionary with the number of agent steps per
                agent ID added to the buffer during the `EpisodeReplayBuffer.add` call.
            agent_to_num_episodes_evicted: A dictionary with the number of episodes per
                agent ID evicted to the buffer during the `EpisodeReplayBuffer.add` call.
            agent_to_num_steps_evicted: A dictionary with the number of agent steps per
                agent ID evicted to the buffer during the `EpisodeReplayBuffer.add` call.
            module_to_num_episodes_added: A dictionary with the number of episodes per
                module ID added to the buffer during the `EpisodeReplayBuffer.add` call.
            module_to_num_steps_added: A dictionary with the number of agent steps per
                module ID added to the buffer during the `EpisodeReplayBuffer.add` call.
            module_to_num_episodes_evicted: A dictionary with the number of episodes per
                module ID evicted to the buffer during the `EpisodeReplayBuffer.add` call.
            module_to_num_steps_evicted: A dictionary with the number of agent steps per
                module ID evicted to the buffer during the `EpisodeReplayBuffer.add` call.
        """
        # Whole buffer episode metrics.
        self.metrics.log_value(
            NUM_EPISODES_STORED,
            self.get_num_episodes(),
            reduce="mean",
            window=self._metrics_num_episodes_for_smoothing,
        )
        # Number of new episodes added. Note, this metric could
        # be zero when ongoing episodes were logged.
        self.metrics.log_value(
            NUM_EPISODES_ADDED,
            num_episodes_added,
            reduce="sum",
            clear_on_reduce=True,
        )
        self.metrics.log_value(
            NUM_EPISODES_ADDED_LIFETIME,
            num_episodes_added,
            reduce="sum",
        )
        self.metrics.log_value(
            NUM_EPISODES_EVICTED,
            num_episodes_evicted,
            reduce="sum",
            clear_on_reduce=True,
        )
        self.metrics.log_value(
            NUM_EPISODES_EVICTED_LIFETIME,
            num_episodes_evicted,
            reduce="sum",
        )
        # Whole buffer step metrics.
        self.metrics.log_value(
            NUM_ENV_STEPS_STORED,
            self.get_num_timesteps(),
            reduce="mean",
            window=self._metrics_num_episodes_for_smoothing,
        )
        self.metrics.log_value(
            NUM_ENV_STEPS_ADDED,
            num_env_steps_added,
            reduce="sum",
            clear_on_reduce=True,
        )
        self.metrics.log_value(
            NUM_ENV_STEPS_ADDED_LIFETIME,
            num_env_steps_added,
            reduce="sum",
        )
        self.metrics.log_value(
            NUM_ENV_STEPS_EVICTED,
            num_env_steps_evicted,
            reduce="sum",
            clear_on_reduce=True,
        )
        self.metrics.log_value(
            NUM_ENV_STEPS_EVICTED_LIFETIME,
            num_env_steps_evicted,
            reduce="sum",
        )

        # Log per-agent metrics.
        for aid in agent_to_num_episodes_added:
            # Number of new episodes added. Note, this metric could
            # be zero.
            self.metrics.log_value(
                (NUM_AGENT_EPISODES_ADDED, aid),
                agent_to_num_episodes_added[aid],
                reduce="sum",
                clear_on_reduce=True,
            )
            self.metrics.log_value(
                (NUM_AGENT_EPISODES_ADDED_LIFETIME, aid),
                agent_to_num_episodes_added[aid],
                reduce="sum",
            )
            # Number of new agent steps added. Note, this metric could
            # be zero, too.
            self.metrics.log_value(
                (NUM_AGENT_STEPS_ADDED, aid),
                agent_to_num_steps_added[aid],
                reduce="sum",
                clear_on_reduce=True,
            )
            self.metrics.log_value(
                (NUM_AGENT_STEPS_ADDED_LIFETIME, aid),
                agent_to_num_steps_added[aid],
                reduce="sum",
            )
        for aid in agent_to_num_episodes_evicted:
            # Number of agent episodes evicted. Note, values could be zero.
            self.metrics.log_value(
                (NUM_AGENT_EPISODES_EVICTED, aid),
                agent_to_num_episodes_evicted[aid],
                reduce="sum",
                clear_on_reduce=True,
            )
            self.metrics.log_value(
                (NUM_AGENT_EPISODES_EVICTED_LIFETIME, aid),
                agent_to_num_episodes_evicted[aid],
                reduce="sum",
            )
            # Number of agent steps evicted. Note, values could be zero.
            self.metrics.log_value(
                (NUM_AGENT_STEPS_EVICTED, aid),
                agent_to_num_steps_evicted[aid],
                reduce="sum",
                clear_on_reduce=True,
            )
            self.metrics.log_value(
                (NUM_AGENT_STEPS_EVICTED_LIFETIME, aid),
                agent_to_num_steps_evicted[aid],
                reduce="sum",
            )
        # Note, we need to loop through the metrics here to receive
        # metrics for all agents (not only the ones that steps).
        for aid in self.metrics.stats[NUM_AGENT_STEPS_ADDED_LIFETIME]:
            # Add default metrics for evicted episodes if not existent.
            if aid not in agent_to_num_episodes_evicted:
                self.metrics.log_value(
                    (NUM_AGENT_EPISODES_EVICTED, aid),
                    0,
                    reduce="sum",
                    clear_on_reduce=True,
                )
                self.metrics.log_value(
                    (NUM_AGENT_EPISODES_EVICTED_LIFETIME, aid),
                    0,
                    reduce="sum",
                )
                self.metrics.log_value(
                    (NUM_AGENT_STEPS_EVICTED, aid),
                    0,
                    reduce="sum",
                    clear_on_reduce=True,
                )
                self.metrics.log_value(
                    (NUM_AGENT_STEPS_EVICTED_LIFETIME, aid),
                    0,
                    reduce="sum",
                )
            # Number of episodes in the buffer.
            agent_episodes_stored = self.metrics.peek(
                (NUM_AGENT_EPISODES_ADDED_LIFETIME, aid)
            ) - self.metrics.peek((NUM_AGENT_EPISODES_EVICTED_LIFETIME, aid))
            self.metrics.log_value(
                (NUM_AGENT_EPISODES_STORED, aid),
                agent_episodes_stored,
                reduce="mean",
                window=self._metrics_num_episodes_for_smoothing,
            )
            # Number of agent steps in the buffer.
            agent_steps_stored = self.metrics.peek(
                (NUM_AGENT_STEPS_ADDED_LIFETIME, aid)
            ) - self.metrics.peek((NUM_AGENT_EPISODES_EVICTED_LIFETIME, aid))

            self.metrics.log_value(
                (NUM_AGENT_STEPS_STORED, aid),
                agent_steps_stored,
                reduce="mean",
                window=self._metrics_num_episodes_for_smoothing,
            )

        # Log per-module metrics.
        for mid in module_to_num_episodes_added:
            # Number of new episodes added. Note, this metric could
            # be zero.
            self.metrics.log_value(
                (NUM_MODULE_EPISODES_ADDED, mid),
                module_to_num_episodes_added[mid],
                reduce="sum",
                clear_on_reduce=True,
            )
            self.metrics.log_value(
                (NUM_MODULE_EPISODES_ADDED_LIFETIME, mid),
                module_to_num_episodes_added[mid],
                reduce="sum",
            )
            # Number of new module steps added. Note, this metric could
            # be zero, too.
            self.metrics.log_value(
                (NUM_MODULE_STEPS_ADDED, mid),
                module_to_num_steps_added[mid],
                reduce="sum",
                clear_on_reduce=True,
            )
            self.metrics.log_value(
                (NUM_MODULE_STEPS_ADDED_LIFETIME, mid),
                module_to_num_steps_added[mid],
                reduce="sum",
            )
        for mid in module_to_num_episodes_evicted:
            # Number of module episodes evicted. Note, values could be zero.
            self.metrics.log_value(
                (NUM_MODULE_EPISODES_EVICTED, mid),
                module_to_num_episodes_evicted[mid],
                reduce="sum",
                clear_on_reduce=True,
            )
            self.metrics.log_value(
                (NUM_MODULE_EPISODES_EVICTED_LIFETIME, mid),
                module_to_num_episodes_evicted[mid],
                reduce="sum",
            )
            # Number of module steps evicted. Note, values could be zero.
            self.metrics.log_value(
                (NUM_MODULE_STEPS_EVICTED, mid),
                module_to_num_steps_evicted[mid],
                reduce="sum",
                clear_on_reduce=True,
            )
            self.metrics.log_value(
                (NUM_MODULE_STEPS_EVICTED_LIFETIME, mid),
                module_to_num_steps_evicted[mid],
                reduce="sum",
            )
        # Note, we need to loop through the metrics here to receive
        # metrics for all agents (not only the ones that steps).
        for mid in self.metrics.stats[NUM_MODULE_STEPS_ADDED_LIFETIME]:
            # Number of episodes in the buffer.
            if mid not in module_to_num_episodes_evicted:
                self.metrics.log_value(
                    (NUM_MODULE_EPISODES_EVICTED, mid),
                    0,
                    reduce="sum",
                    clear_on_reduce=True,
                )
                self.metrics.log_value(
                    (NUM_MODULE_EPISODES_EVICTED_LIFETIME, mid),
                    0,
                    reduce="sum",
                )
                self.metrics.log_value(
                    (NUM_MODULE_STEPS_EVICTED, mid),
                    0,
                    reduce="sum",
                    clear_on_reduce=True,
                )
                self.metrics.log_value(
                    (NUM_MODULE_STEPS_EVICTED_LIFETIME, mid),
                    0,
                    reduce="sum",
                )
            module_episodes_stored = self.metrics.peek(
                (NUM_MODULE_EPISODES_ADDED_LIFETIME, mid)
            ) - self.metrics.peek((NUM_MODULE_EPISODES_EVICTED_LIFETIME, mid))
            self.metrics.log_value(
                (NUM_MODULE_EPISODES_STORED, mid),
                module_episodes_stored,
                reduce="mean",
                window=self._metrics_num_episodes_for_smoothing,
            )

    @override(ReplayBufferInterface)
    def sample(
        self,
        num_items: Optional[int] = None,
        *,
        batch_size_B: Optional[int] = None,
        batch_length_T: Optional[int] = None,
        n_step: Optional[Union[int, Tuple]] = None,
        beta: float = 0.0,
        gamma: float = 0.99,
        include_infos: bool = False,
        include_extra_model_outputs: bool = False,
        sample_episodes: Optional[bool] = False,
        to_numpy: bool = False,
        # TODO (simon): Check, if we need here 1 as default.
        lookback: int = 0,
        min_batch_length_T: int = 0,
        **kwargs,
    ) -> Union[SampleBatchType, SingleAgentEpisode]:
        """Samples from a buffer in a randomized way.

        Each sampled item defines a transition of the form:

        `(o_t, a_t, sum(r_(t+1:t+n+1)), o_(t+n), terminated_(t+n), truncated_(t+n))`

        where `o_t` is drawn by randomized sampling.`n` is defined by the `n_step`
        applied.

        If requested, `info`s of a transitions last timestep `t+n` and respective
        extra model outputs (e.g. action log-probabilities) are added to
        the batch.

        Args:
            num_items: Number of items (transitions) to sample from this
                buffer.
            batch_size_B: The number of rows (transitions) to return in the
                batch
            batch_length_T: THe sequence length to sample. At this point in time
                only sequences of length 1 are possible.
            n_step: The n-step to apply. For the default the batch contains in
                `"new_obs"` the observation and in `"obs"` the observation `n`
                time steps before. The reward will be the sum of rewards
                collected in between these two observations and the action will
                be the one executed n steps before such that we always have the
                state-action pair that triggered the rewards.
                If `n_step` is a tuple, it is considered as a range to sample
                from. If `None`, we use `n_step=1`.
            gamma: The discount factor to be used when applying n-step calculations.
                The default of `0.99` should be replaced by the `Algorithm`s
                discount factor.
            include_infos: A boolean indicating, if `info`s should be included in
                the batch. This could be of advantage, if the `info` contains
                values from the environment important for loss computation. If
                `True`, the info at the `"new_obs"` in the batch is included.
            include_extra_model_outputs: A boolean indicating, if
                `extra_model_outputs` should be included in the batch. This could be
                of advantage, if the `extra_mdoel_outputs`  contain outputs from the
                model important for loss computation and only able to compute with the
                actual state of model e.g. action log-probabilities, etc.). If `True`,
                the extra model outputs at the `"obs"` in the batch is included (the
                timestep at which the action is computed).
            to_numpy: If episodes should be numpy'ized.
            lookback: A desired lookback. Any non-negative integer is valid.
            min_batch_length_T: An optional minimal length when sampling sequences. It
                ensures that sampled sequences are at least `min_batch_length_T` time
                steps long. This can be used to prevent empty sequences during
                learning, when using a burn-in period for stateful `RLModule`s. In rare
                cases, such as when episodes are very short early in training, this may
                result in longer sampling times.

        Returns:
            Either a batch with transitions in each row or (if `return_episodes=True`)
            a list of 1-step long episodes containing all basic episode data and if
            requested infos and extra model outputs.
        """

        if sample_episodes:
            return self._sample_episodes(
                num_items=num_items,
                batch_size_B=batch_size_B,
                batch_length_T=batch_length_T,
                n_step=n_step,
                beta=beta,
                gamma=gamma,
                include_infos=include_infos,
                include_extra_model_outputs=include_extra_model_outputs,
                to_numpy=to_numpy,
                lookback=lookback,
                min_batch_length_T=min_batch_length_T,
            )
        else:
            return self._sample_batch(
                num_items=num_items,
                batch_size_B=batch_size_B,
                batch_length_T=batch_length_T,
            )

    def _sample_batch(
        self,
        num_items: Optional[int] = None,
        *,
        batch_size_B: Optional[int] = None,
        batch_length_T: Optional[int] = None,
    ) -> SampleBatchType:
        """Returns a batch of size B (number of "rows"), where each row has length T.

        Each row contains consecutive timesteps from an episode, but might not start
        at the beginning of that episode. Should an episode end within such a
        row (trajectory), a random next episode (starting from its t0) will be
        concatenated to that row. For more details, see the docstring of the
        EpisodeReplayBuffer class.

        Args:
            num_items: See `batch_size_B`. For compatibility with the
                `ReplayBufferInterface` abstract base class.
            batch_size_B: The number of rows (trajectories) to return in the batch.
            batch_length_T: The length of each row (in timesteps) to return in the
                batch.

        Returns:
            The sampled batch (observations, actions, rewards, terminateds, truncateds)
                of dimensions [B, T, ...].
        """
        if num_items is not None:
            assert batch_size_B is None, (
                "Cannot call `sample()` with both `num_items` and `batch_size_B` "
                "provided! Use either one."
            )
            batch_size_B = num_items

        # Use our default values if no sizes/lengths provided.
        batch_size_B = batch_size_B or self.batch_size_B
        batch_length_T = batch_length_T or self.batch_length_T

        # Rows to return.
        observations = [[] for _ in range(batch_size_B)]
        actions = [[] for _ in range(batch_size_B)]
        rewards = [[] for _ in range(batch_size_B)]
        is_first = [[False] * batch_length_T for _ in range(batch_size_B)]
        is_last = [[False] * batch_length_T for _ in range(batch_size_B)]
        is_terminated = [[False] * batch_length_T for _ in range(batch_size_B)]
        is_truncated = [[False] * batch_length_T for _ in range(batch_size_B)]

        # Store the unique episode buffer indices to determine sample variation.
        sampled_episode_idxs = set()
        # Store the unique env step buffer indices to determine sample variation
        sampled_env_step_idxs = set()

        B = 0
        T = 0
        while B < batch_size_B:
            # Pull a new uniform random index tuple: (eps_idx, ts_in_eps_idx).
            index_tuple = self._indices[self.rng.integers(len(self._indices))]

            # Compute the actual episode index (offset by the number of
            # already evicted episodes).
            episode_idx, episode_ts = (
                index_tuple[0] - self._num_episodes_evicted,
                index_tuple[1],
            )
            episode = self.episodes[episode_idx]

            # Starting a new chunk, set is_first to True.
            is_first[B][T] = True

            # Begin of new batch item (row).
            if len(rewards[B]) == 0:
                # And we are at the start of an episode: Set reward to 0.0.
                if episode_ts == 0:
                    rewards[B].append(0.0)
                # We are in the middle of an episode: Set reward to the previous
                # timestep's values.
                else:
                    rewards[B].append(episode.rewards[episode_ts - 1])
            # We are in the middle of a batch item (row). Concat next episode to this
            # row from the next episode's beginning. In other words, we never concat
            # a middle of an episode to another truncated one.
            else:
                episode_ts = 0
                rewards[B].append(0.0)

            observations[B].extend(episode.observations[episode_ts:])
            # Repeat last action to have the same number of actions than observations.
            actions[B].extend(episode.actions[episode_ts:])
            actions[B].append(episode.actions[-1])
            # Number of rewards are also the same as observations b/c we have the
            # initial 0.0 one.
            rewards[B].extend(episode.rewards[episode_ts:])
            assert len(observations[B]) == len(actions[B]) == len(rewards[B])

            T = min(len(observations[B]), batch_length_T)

            # Set is_last=True.
            is_last[B][T - 1] = True
            # If episode is terminated and we have reached the end of it, set
            # is_terminated=True.
            if episode.is_terminated and T == len(observations[B]):
                is_terminated[B][T - 1] = True
            # If episode is truncated and we have reached the end of it, set
            # is_truncated=True.
            elif episode.is_truncated and T == len(observations[B]):
                is_truncated[B][T - 1] = True

            # We are done with this batch row.
            if T == batch_length_T:
                # We may have overfilled this row: Clip trajectory at the end.
                observations[B] = observations[B][:batch_length_T]
                actions[B] = actions[B][:batch_length_T]
                rewards[B] = rewards[B][:batch_length_T]
                # Start filling the next row.
                B += 1
                T = 0
            # Add the episode buffer index to the set of episode indexes.
            sampled_episode_idxs.add(episode_idx)
            # Add the unique hashcode for the timestep.
            sampled_env_step_idxs.add(
                hashlib.sha256(f"{episode.id_}-{episode_ts}".encode()).hexdigest()
            )

        # Update our sampled counter.
        self.sampled_timesteps += batch_size_B * batch_length_T

        # Update the sample metrics.
        num_env_steps_sampled = batch_size_B * batch_length_T
        num_episodes_per_sample = len(sampled_episode_idxs)
        num_env_steps_per_sample = len(sampled_env_step_idxs)
        num_resamples = 0
        agent_to_sample_size = {DEFAULT_AGENT_ID: num_env_steps_sampled}
        agent_to_num_episodes_per_sample = {DEFAULT_AGENT_ID: num_episodes_per_sample}
        agent_to_num_steps_per_sample = {DEFAULT_AGENT_ID: num_env_steps_per_sample}
        agent_to_num_resamples = {DEFAULT_AGENT_ID: num_resamples}
        module_to_num_steps_sampled = {DEFAULT_MODULE_ID: num_env_steps_sampled}
        module_to_num_episodes_per_sample = {DEFAULT_MODULE_ID: num_episodes_per_sample}
        module_to_num_steps_per_sample = {DEFAULT_MODULE_ID: num_env_steps_per_sample}
        module_to_num_resamples = {DEFAULT_MODULE_ID: num_resamples}
        self._update_sample_metrics(
            num_env_steps_sampled=num_env_steps_sampled,
            num_episodes_per_sample=num_episodes_per_sample,
            num_env_steps_per_sample=num_env_steps_per_sample,
            num_resamples=num_resamples,
            sampled_n_step=None,
            agent_to_num_steps_sampled=agent_to_sample_size,
            agent_to_num_episodes_per_sample=agent_to_num_episodes_per_sample,
            agent_to_num_steps_per_sample=agent_to_num_steps_per_sample,
            agent_to_num_resamples=agent_to_num_resamples,
            agent_to_sampled_n_step=None,
            module_to_num_steps_sampled=module_to_num_steps_sampled,
            module_to_num_episodes_per_sample=module_to_num_episodes_per_sample,
            module_to_num_steps_per_sample=module_to_num_steps_per_sample,
            module_to_num_resamples=module_to_num_resamples,
            module_to_sampled_n_step=None,
        )

        # TODO: Return SampleBatch instead of this simpler dict.
        ret = {
            "obs": np.array(observations),
            "actions": np.array(actions),
            "rewards": np.array(rewards),
            "is_first": np.array(is_first),
            "is_last": np.array(is_last),
            "is_terminated": np.array(is_terminated),
            "is_truncated": np.array(is_truncated),
        }

        return ret

    def _sample_episodes(
        self,
        num_items: Optional[int] = None,
        *,
        batch_size_B: Optional[int] = None,
        batch_length_T: Optional[int] = None,
        n_step: Optional[Union[int, Tuple]] = None,
        gamma: float = 0.99,
        include_infos: bool = False,
        include_extra_model_outputs: bool = False,
        to_numpy: bool = False,
        lookback: int = 1,
        min_batch_length_T: int = 0,
        **kwargs,
    ) -> List[SingleAgentEpisode]:
        """Samples episodes from a buffer in a randomized way.

        Each sampled item defines a transition of the form:

        `(o_t, a_t, sum(r_(t+1:t+n+1)), o_(t+n), terminated_(t+n), truncated_(t+n))`

        where `o_t` is drawn by randomized sampling.`n` is defined by the `n_step`
        applied.

        If requested, `info`s of a transitions last timestep `t+n` and respective
        extra model outputs (e.g. action log-probabilities) are added to
        the batch.

        Args:
            num_items: Number of items (transitions) to sample from this
                buffer.
            batch_size_B: The number of rows (transitions) to return in the
                batch
            batch_length_T: The sequence length to sample. Can be either `None`
                (the default) or any positive integer.
            n_step: The n-step to apply. For the default the batch contains in
                `"new_obs"` the observation and in `"obs"` the observation `n`
                time steps before. The reward will be the sum of rewards
                collected in between these two observations and the action will
                be the one executed n steps before such that we always have the
                state-action pair that triggered the rewards.
                If `n_step` is a tuple, it is considered as a range to sample
                from. If `None`, we use `n_step=1`.
            gamma: The discount factor to be used when applying n-step calculations.
                The default of `0.99` should be replaced by the `Algorithm`s
                discount factor.
            include_infos: A boolean indicating, if `info`s should be included in
                the batch. This could be of advantage, if the `info` contains
                values from the environment important for loss computation. If
                `True`, the info at the `"new_obs"` in the batch is included.
            include_extra_model_outputs: A boolean indicating, if
                `extra_model_outputs` should be included in the batch. This could be
                of advantage, if the `extra_mdoel_outputs`  contain outputs from the
                model important for loss computation and only able to compute with the
                actual state of model e.g. action log-probabilities, etc.). If `True`,
                the extra model outputs at the `"obs"` in the batch is included (the
                timestep at which the action is computed).
            to_numpy: If episodes should be numpy'ized.
            lookback: A desired lookback. Any non-negative integer is valid.
            min_batch_length_T: An optional minimal length when sampling sequences. It
                ensures that sampled sequences are at least `min_batch_length_T` time
                steps long. This can be used to prevent empty sequences during
                learning, when using a burn-in period for stateful `RLModule`s. In rare
                cases, such as when episodes are very short early in training, this may
                result in longer sampling times.

        Returns:
            A list of 1-step long episodes containing all basic episode data and if
            requested infos and extra model outputs.
        """
        if num_items is not None:
            assert batch_size_B is None, (
                "Cannot call `sample()` with both `num_items` and `batch_size_B` "
                "provided! Use either one."
            )
            batch_size_B = num_items

        # Use our default values if no sizes/lengths provided.
        batch_size_B = batch_size_B or self.batch_size_B

        assert n_step is not None, (
            "When sampling episodes, `n_step` must be "
            "provided, but `n_step` is `None`."
        )
        # If no sequence should be sampled, we sample n-steps.
        if not batch_length_T:
            # Sample the `n_step`` itself, if necessary.
            actual_n_step = n_step
            random_n_step = isinstance(n_step, tuple)
        # Otherwise we use an n-step of 1.
        else:
            assert (
                not isinstance(n_step, tuple) and n_step == 1
            ), "When sampling sequences n-step must be 1."
            actual_n_step = n_step

        # Keep track of the indices that were sampled last for updating the
        # weights later (see `ray.rllib.utils.replay_buffer.utils.
        # update_priorities_in_episode_replay_buffer`).
        self._last_sampled_indices = []

        sampled_episodes = []
        # Record all the env step buffer indices that are contained in the sample.
        sampled_env_step_idxs = set()
        # Record all the episode buffer indices that are contained in the sample.
        sampled_episode_idxs = set()
        # Record all n-steps that have been used.
        sampled_n_steps = []
        # Record all the env step buffer indices that are contained in the sample.
        sampled_env_step_idxs = set()
        # Record all the episode buffer indices that are contained in the sample.
        sampled_episode_idxs = set()
        # Record all n-steps that have been used.
        sampled_n_steps = []

        B = 0
        while B < batch_size_B:
            # Pull a new uniform random index tuple: (eps_idx, ts_in_eps_idx).
            index_tuple = self._indices[self.rng.integers(len(self._indices))]

            # Compute the actual episode index (offset by the number of
            # already evicted episodes).
            episode_idx, episode_ts = (
                index_tuple[0] - self._num_episodes_evicted,
                index_tuple[1],
            )
            episode = self.episodes[episode_idx]

            # If we use random n-step sampling, draw the n-step for this item.
            if not batch_length_T and random_n_step:
                actual_n_step = int(self.rng.integers(n_step[0], n_step[1]))

            # Skip, if we are too far to the end and `episode_ts` + n_step would go
            # beyond the episode's end.
            if min_batch_length_T > 0 and episode_ts + min_batch_length_T >= len(
                episode
            ):
                continue
            if episode_ts + (batch_length_T or 0) + (actual_n_step - 1) > len(episode):
                actual_length = len(episode)
            else:
                actual_length = episode_ts + (batch_length_T or 0) + (actual_n_step - 1)

            # If no sequence should be sampled, we sample here the n-step.
            if not batch_length_T:
                sampled_episode = episode.slice(
                    slice(
                        episode_ts,
                        episode_ts + actual_n_step,
                    )
                )
                # Note, this will be the reward after executing action
                # `a_(episode_ts-n_step+1)`. For `n_step>1` this will be the discounted
                # sum of all discounted rewards that were collected over the last n
                # steps.
                raw_rewards = sampled_episode.get_rewards()

                rewards = scipy.signal.lfilter(
                    [1], [1, -gamma], raw_rewards[::-1], axis=0
                )[-1]

                sampled_episode = SingleAgentEpisode(
                    id_=sampled_episode.id_,
                    agent_id=sampled_episode.agent_id,
                    module_id=sampled_episode.module_id,
                    observation_space=sampled_episode.observation_space,
                    action_space=sampled_episode.action_space,
                    observations=[
                        sampled_episode.get_observations(0),
                        sampled_episode.get_observations(-1),
                    ],
                    actions=[sampled_episode.get_actions(0)],
                    rewards=[rewards],
                    infos=[
                        sampled_episode.get_infos(0),
                        sampled_episode.get_infos(-1),
                    ],
                    terminated=sampled_episode.is_terminated,
                    truncated=sampled_episode.is_truncated,
                    extra_model_outputs={
                        **(
                            {
                                k: [episode.get_extra_model_outputs(k, 0)]
                                for k in episode.extra_model_outputs.keys()
                            }
                            if include_extra_model_outputs
                            else {}
                        ),
                    },
                    t_started=episode_ts,
                    len_lookback_buffer=0,
                )
            # Otherwise we simply slice the episode.
            else:
                sampled_episode = episode.slice(
                    slice(
                        episode_ts,
                        actual_length,
                    ),
                    len_lookback_buffer=lookback,
                )
            # Record a has for the episode ID and timestep inside of the episode.
            sampled_env_step_idxs.add(
                hashlib.sha256(f"{episode.id_}-{episode_ts}".encode()).hexdigest()
            )
            # Record a has for the episode ID and timestep inside of the episode.
            sampled_env_step_idxs.add(
                hashlib.sha256(f"{episode.id_}-{episode_ts}".encode()).hexdigest()
            )
            # Remove reference to sampled episode.
            del episode

            # Add the actually chosen n-step in this episode.
            sampled_episode.extra_model_outputs["n_step"] = InfiniteLookbackBuffer(
                np.full((len(sampled_episode) + lookback,), actual_n_step),
                lookback=lookback,
            )
            # Some loss functions need `weights` - which are only relevant when
            # prioritizing.
            sampled_episode.extra_model_outputs["weights"] = InfiniteLookbackBuffer(
                np.ones((len(sampled_episode) + lookback,)), lookback=lookback
            )

            # Append the sampled episode.
            sampled_episodes.append(sampled_episode)
            sampled_episode_idxs.add(episode_idx)
            sampled_n_steps.append(actual_n_step)

            # Increment counter.
            B += (actual_length - episode_ts - (actual_n_step - 1) + 1) or 1

        # Update the metric.
        self.sampled_timesteps += batch_size_B

        # Update the sample metrics.
        num_env_steps_sampled = batch_size_B
        num_episodes_per_sample = len(sampled_episode_idxs)
        num_env_steps_per_sample = len(sampled_env_step_idxs)
        sampled_n_step = sum(sampled_n_steps) / batch_size_B
        num_resamples = 0
        agent_to_num_steps_sampled = {DEFAULT_AGENT_ID: num_env_steps_sampled}
        agent_to_num_episodes_per_sample = {DEFAULT_AGENT_ID: num_episodes_per_sample}
        agent_to_num_steps_per_sample = {DEFAULT_AGENT_ID: num_env_steps_per_sample}
        agent_to_sampled_n_step = {DEFAULT_AGENT_ID: sampled_n_step}
        agent_to_num_resamples = {DEFAULT_AGENT_ID: num_resamples}
        module_to_num_steps_sampled = {DEFAULT_MODULE_ID: num_env_steps_sampled}
        module_to_num_episodes_per_sample = {DEFAULT_MODULE_ID: num_episodes_per_sample}
        module_to_num_steps_per_sample = {DEFAULT_MODULE_ID: num_env_steps_per_sample}
        module_to_sampled_n_step = {DEFAULT_MODULE_ID: sampled_n_step}
        module_to_num_resamples = {DEFAULT_MODULE_ID: num_resamples}
        self._update_sample_metrics(
            num_env_steps_sampled=num_env_steps_sampled,
            num_episodes_per_sample=num_episodes_per_sample,
            num_env_steps_per_sample=num_env_steps_per_sample,
            sampled_n_step=sampled_n_step,
            num_resamples=num_resamples,
            agent_to_num_steps_sampled=agent_to_num_steps_sampled,
            agent_to_num_episodes_per_sample=agent_to_num_episodes_per_sample,
            agent_to_num_steps_per_sample=agent_to_num_steps_per_sample,
            agent_to_sampled_n_step=agent_to_sampled_n_step,
            agent_to_num_resamples=agent_to_num_resamples,
            module_to_num_steps_sampled=module_to_num_steps_sampled,
            module_to_num_episodes_per_sample=module_to_num_episodes_per_sample,
            module_to_num_steps_per_sample=module_to_num_steps_per_sample,
            module_to_sampled_n_step=module_to_sampled_n_step,
            module_to_num_resamples=module_to_num_resamples,
        )

        return sampled_episodes

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def _update_sample_metrics(
        self,
        *,
        num_env_steps_sampled: int,
        num_episodes_per_sample: int,
        num_env_steps_per_sample: int,
        sampled_n_step: Optional[float],
        num_resamples: int,
        agent_to_num_steps_sampled: Dict[AgentID, int],
        agent_to_num_episodes_per_sample: Dict[AgentID, int],
        agent_to_num_steps_per_sample: Dict[AgentID, int],
        agent_to_sampled_n_step: Dict[AgentID, float],
        agent_to_num_resamples: Dict[AgentID, int],
        module_to_num_steps_sampled: Dict[ModuleID, int],
        module_to_num_episodes_per_sample: Dict[ModuleID, int],
        module_to_num_steps_per_sample: Dict[ModuleID, int],
        module_to_sampled_n_step: Dict[ModuleID, float],
        module_to_num_resamples: Dict[ModuleID, int],
        **kwargs: Dict[str, Any],
    ) -> None:
        """Updates the replay buffer's sample metrics.

        Args:
            num_env_steps_sampled: The number of environment steps sampled
                this iteration in the `sample` method.
            num_episodes_per_sample: The number of unique episodes in the
                sample.
            num_env_steps_per_sample: The number of unique environment steps
                in the sample.
            sampled_n_step: The mean n-step used in the sample. Note, this
                is constant, if the n-step is not sampled.
            num_resamples: The number of resamples in a single call to
                `PrioritizedEpisodeReplayBuffer.sample`. A resampling is triggered
                when the sampled timestep is to near to the episode end to cover the
                required n-step.
            agent_to_num_steps_sampled: A dictionary with the number of agent
                steps per agent ID sampled during the `EpisodeReplayBuffer.sample`
                call.
            agent_to_num_episodes_per_sample: A dictionary with the number of
                unique episodes per agent ID contained in the sample returned by
                the `EpisodeReplayBuffer.sample` call.
            agent_to_num_steps_per_sample: A dictionary with the number of
                unique agent steps per agent ID contained in the sample returned by
                the `EpisodeReplayBuffer.sample` call.
            agent_to_sampled_n_step: A dictionary with the mean n-step per agent ID
                used in the sample returned by the `EpisodeReplayBuffer.sample` call.
            agent_to_num_resamples: A dictionary with the number of resamples per
                agent ID in a single call to `PrioritizedEpisodeReplayBuffer.sample`.
                A resampling is triggered when the sampled timestep is to near to the
                episode end to cover the required n-step.
            module_to_num_steps_sampled: A dictionary with the number of module
                steps per module ID sampled during the `EpisodeReplayBuffer.sample`
                call.
            module_to_num_episodes_per_sample: A dictionary with the number of
                unique episodes per module ID contained in the sample returned by
                the `EpisodeReplayBuffer.sample` call.
            module_to_num_steps_per_sample: A dictionary with the number of
                unique module steps per module ID contained in the sample returned by
                the `EpisodeReplayBuffer.sample` call.
            module_to_sampled_n_step: A dictionary with the mean n-step per module ID
                used in the sample returned by the `EpisodeReplayBuffer.sample` call.
            module_to_num_resamples: A dictionary with the number of resamples per
                module ID in a single call to `PrioritizedEpisodeReplayBuffer.sample`.
                A resampling is triggered when the sampled timestep is to near to the
                episode end to cover the required n-step.
        """
        # Whole buffer sampled env steps metrics.
        self.metrics.log_value(
            NUM_EPISODES_PER_SAMPLE,
            num_episodes_per_sample,
            reduce="sum",
            clear_on_reduce=True,
        )
        self.metrics.log_value(
            NUM_ENV_STEPS_PER_SAMPLE,
            num_env_steps_per_sample,
            reduce="sum",
            clear_on_reduce=True,
        )
        self.metrics.log_value(
            NUM_ENV_STEPS_PER_SAMPLE_LIFETIME,
            num_env_steps_per_sample,
            reduce="sum",
        )
        self.metrics.log_value(
            NUM_ENV_STEPS_SAMPLED,
            num_env_steps_sampled,
            reduce="sum",
            clear_on_reduce=True,
        )
        self.metrics.log_value(
            NUM_ENV_STEPS_SAMPLED_LIFETIME,
            num_env_steps_sampled,
            reduce="sum",
        )
        self.metrics.log_value(
            ENV_STEP_UTILIZATION,
            self.metrics.peek(NUM_ENV_STEPS_PER_SAMPLE_LIFETIME)
            / self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME),
            reduce="mean",
            window=self._metrics_num_episodes_for_smoothing,
        )
        if sampled_n_step:
            self.metrics.log_value(
                ACTUAL_N_STEP,
                sampled_n_step,
                reduce="mean",
                window=self._metrics_num_episodes_for_smoothing,
            )
            for aid in agent_to_sampled_n_step:
                self.metrics.log_value(
                    (AGENT_ACTUAL_N_STEP, aid),
                    agent_to_sampled_n_step[aid],
                    reduce="mean",
                    window=self._metrics_num_episodes_for_smoothing,
                )
            for mid in module_to_sampled_n_step:
                self.metrics.log_value(
                    (MODULE_ACTUAL_N_STEP, mid),
                    module_to_sampled_n_step[mid],
                    reduce="mean",
                    window=self._metrics_num_episodes_for_smoothing,
                )

        self.metrics.log_value(
            NUM_RESAMPLES,
            num_resamples,
            reduce="sum",
            clear_on_reduce=True,
        )

        # Add per-agent metrics.
        for aid in agent_to_num_steps_sampled:
            self.metrics.log_value(
                (NUM_AGENT_STEPS_SAMPLED, aid),
                agent_to_num_steps_sampled[aid],
                reduce="sum",
                clear_on_reduce=True,
            )
            # TODO (simon): Check, if we can then deprecate
            # self.sampled_timesteps.
            self.metrics.log_value(
                (NUM_AGENT_STEPS_SAMPLED_LIFETIME, aid),
                agent_to_num_steps_sampled[aid],
                reduce="sum",
            )
            self.metrics.log_value(
                (NUM_AGENT_EPISODES_PER_SAMPLE, aid),
                agent_to_num_episodes_per_sample[aid],
                reduce="sum",
                clear_on_reduce=True,
            )
            self.metrics.log_value(
                (NUM_AGENT_STEPS_PER_SAMPLE, aid),
                agent_to_num_steps_per_sample[aid],
                reduce="sum",
                clear_on_reduce=True,
            )
            self.metrics.log_value(
                (NUM_AGENT_STEPS_PER_SAMPLE_LIFETIME, aid),
                agent_to_num_steps_per_sample[aid],
                reduce="sum",
            )
            self.metrics.log_value(
                (AGENT_STEP_UTILIZATION, aid),
                self.metrics.peek((NUM_AGENT_STEPS_PER_SAMPLE_LIFETIME, aid))
                / self.metrics.peek((NUM_AGENT_STEPS_SAMPLED_LIFETIME, aid)),
                reduce="mean",
                window=self._metrics_num_episodes_for_smoothing,
            )
            self.metrics.log_value(
                (NUM_AGENT_RESAMPLES, aid),
                agent_to_num_resamples[aid],
                reduce="sum",
                clear_on_reduce=True,
            )
        # Add per-module metrics.
        for mid in module_to_num_steps_sampled:
            self.metrics.log_value(
                (NUM_MODULE_STEPS_SAMPLED, mid),
                module_to_num_steps_sampled[mid],
                reduce="sum",
                clear_on_reduce=True,
            )
            # TODO (simon): Check, if we can then deprecate
            # self.sampled_timesteps.
            self.metrics.log_value(
                (NUM_MODULE_STEPS_SAMPLED_LIFETIME, mid),
                module_to_num_steps_sampled[mid],
                reduce="sum",
            )
            self.metrics.log_value(
                (NUM_MODULE_EPISODES_PER_SAMPLE, mid),
                module_to_num_episodes_per_sample[mid],
                reduce="sum",
                clear_on_reduce=True,
            )
            self.metrics.log_value(
                (NUM_MODULE_STEPS_PER_SAMPLE, mid),
                module_to_num_steps_per_sample[mid],
                reduce="sum",
                clear_on_reduce=True,
            )
            self.metrics.log_value(
                (NUM_MODULE_STEPS_PER_SAMPLE_LIFETIME, mid),
                module_to_num_steps_per_sample[mid],
                reduce="sum",
            )
            self.metrics.log_value(
                (MODULE_STEP_UTILIZATION, mid),
                self.metrics.peek((NUM_MODULE_STEPS_PER_SAMPLE_LIFETIME, mid))
                / self.metrics.peek((NUM_MODULE_STEPS_SAMPLED_LIFETIME, mid)),
                reduce="mean",
                window=self._metrics_num_episodes_for_smoothing,
            )
            self.metrics.log_value(
                (NUM_MODULE_RESAMPLES, mid),
                module_to_num_resamples[mid],
                reduce="sum",
                clear_on_reduce=True,
            )

    # TODO (simon): Check, if we can instead peek into the metrics
    # and deprecate all variables.
    def get_num_episodes(self, module_id: Optional[ModuleID] = None) -> int:
        """Returns number of episodes (completed or truncated) stored in the buffer."""
        return len(self.episodes)

    def get_num_episodes_evicted(self, module_id: Optional[ModuleID] = None) -> int:
        """Returns number of episodes that have been evicted from the buffer."""
        return self._num_episodes_evicted

    def get_num_timesteps(self, module_id: Optional[ModuleID] = None) -> int:
        """Returns number of individual timesteps stored in the buffer."""
        return len(self._indices)

    def get_sampled_timesteps(self, module_id: Optional[ModuleID] = None) -> int:
        """Returns number of timesteps that have been sampled in buffer's lifetime."""
        return self.sampled_timesteps

    def get_added_timesteps(self, module_id: Optional[ModuleID] = None) -> int:
        """Returns number of timesteps that have been added in buffer's lifetime."""
        return self._num_timesteps_added

    def get_metrics(self) -> ResultDict:
        """Returns the metrics of the buffer and reduces them."""
        return self.metrics.reduce()

    @override(ReplayBufferInterface)
    def get_state(self) -> Dict[str, Any]:
        """Gets a pickable state of the buffer.

        This is used for checkpointing the buffer's state. It is specifically helpful,
        for example, when a trial is paused and resumed later on. The buffer's state
        can be saved to disk and reloaded when the trial is resumed.

        Returns:
            A dict containing all necessary information to restore the buffer's state.
        """
        return {
            "episodes": [eps.get_state() for eps in self.episodes],
            "episode_id_to_index": list(self.episode_id_to_index.items()),
            "_num_episodes_evicted": self._num_episodes_evicted,
            "_indices": self._indices,
            "_num_timesteps": self._num_timesteps,
            "_num_timesteps_added": self._num_timesteps_added,
            "sampled_timesteps": self.sampled_timesteps,
        }

    @override(ReplayBufferInterface)
    def set_state(self, state) -> None:
        """Sets the state of a buffer from a previously stored state.

        See `get_state()` for more information on what is stored in the state. This
        method is used to restore the buffer's state from a previously stored state.
        It is specifically helpful, for example, when a trial is paused and resumed
        later on. The buffer's state can be saved to disk and reloaded when the trial
        is resumed.

        Args:
            state: The state to restore the buffer from.
        """
        self._set_episodes(state)
        self.episode_id_to_index = dict(state["episode_id_to_index"])
        self._num_episodes_evicted = state["_num_episodes_evicted"]
        self._indices = state["_indices"]
        self._num_timesteps = state["_num_timesteps"]
        self._num_timesteps_added = state["_num_timesteps_added"]
        self.sampled_timesteps = state["sampled_timesteps"]

    def _set_episodes(self, state) -> None:
        """Sets the episodes from the state.

        Note, this method is used for class inheritance purposes. It is specifically
        helpful when a subclass of this class wants to override the behavior of how
        episodes are set from the state. By default, it sets `SingleAgentEpuisode`s,
        but subclasses can override this method to set episodes of a different type.
        """
        if not self.episodes:
            self.episodes = deque(
                [
                    SingleAgentEpisode.from_state(eps_data)
                    for eps_data in state["episodes"]
                ]
            )
